##################################################################################
# Purpose: A class to fit multi-class and multi-label classifiers
# Author: Haohan Chen
##################################################################################

import os
import pandas as pd
import numpy as np
import scipy
from scipy.special import softmax
from matplotlib import pyplot as plt
import seaborn as sns
import shutil

from sklearn.metrics import (
    f1_score, hamming_loss, label_ranking_loss, label_ranking_average_precision_score, 
    coverage_error, precision_recall_curve, average_precision_score,
    roc_auc_score, auc, roc_curve, confusion_matrix, classification_report
    )

from simpletransformers.classification import ClassificationModel, ClassificationArgs
from simpletransformers.classification import MultiLabelClassificationModel, MultiLabelClassificationArgs



class classify_tweets:
    """Fit, evaluate and visualize simpletransformers classifiers for tweets.

    ``model_type`` is ``"Multi-class"`` (``ClassificationModel``) or
    ``"Multi-label"`` (``MultiLabelClassificationModel``).

    Expected call order::

        clf.data_setup()             # convert labels, compute weights
        clf.setup_model()            # build self.model
        clf.train_model()            # train, evaluating after each epoch
        clf.remove_bad_model([...])  # optional: delete non-best checkpoints
        clf.evaluate_best_model(d_ts)
        clf.visualize_evaluation(title, print_fig)
    """

    # Columns of training_progress_scores.csv for which a smaller value is
    # better. Every other metric is treated as larger-is-better when the best
    # checkpoint is picked.
    _LOWER_IS_BETTER = frozenset({
        "train_loss", "eval_loss", "hamming_loss", "label_ranking_loss", "coverage_error",
    })

    def __init__(self, transformer_spec, model_name, model_directory, model_type, n_epochs, data, random_seed, codebook, weighted):
        """Store the configuration. No data processing or model creation happens here.

        Args:
            transformer_spec: ``(model_type, model_name_or_path)`` pair handed to
                simpletransformers, e.g. ``("roberta", "roberta-base")``.
            model_name: Display name. ``_load_model`` replaces it with the name
                of the selected checkpoint.
            model_directory: Root directory for the ``output``, ``cache``,
                ``best`` and ``eval`` sub-directories.
            model_type: ``"Multi-class"`` or ``"Multi-label"``.
            n_epochs: Number of training epochs.
            data: ``(train_df, validation_df)``. Each DataFrame has a ``labels``
                column holding equal-length 0/1 lists, one entry per codebook label.
            random_seed: Passed to simpletransformers as ``manual_seed``.
            codebook: Human-readable label names, one per label column.
            weighted: If truthy, use inverse-frequency class weights.
        """
        self.model_directory = model_directory
        self.model_name = model_name
        self.model_type = model_type
        self.n_epochs = n_epochs
        self.random_seed = random_seed
        self.d = data  # (training DataFrame, validation DataFrame); see data_setup()
        self.codebook = codebook
        self.weighted = weighted
        self.transformer_spec = transformer_spec

    def data_setup(self):
        """Convert raw label lists into the format simpletransformers expects.

        Multi-class:
          * One label column (binary task): labels become 0/1 integers and
            ``'Other'`` is *prepended* to the codebook, so class index ``i`` is
            named ``codebook[i]`` (0 = Other, 1 = the coded category).
          * Several label columns: if some training rows have no positive
            label, a ``'None of the above'`` class is appended. Each row's label
            then becomes the index of its first positive column.
        Multi-label:
          Labels stay as lists.

        Sets ``num_label``, ``weight``, ``d_tr`` and ``d_va``. The DataFrames in
        ``self.d`` are copied, not modified.

        Raises:
            ValueError: if ``model_type`` is unknown, if the label lists are
                empty, or if ``weighted`` is set and a class/label has no
                training example (its inverse-frequency weight would be infinite).
        """
        d_tr = self.d[0].copy()
        d_va = self.d[1].copy()
        num_obs = len(d_tr)

        if self.model_type == "Multi-class":
            labels_tr = np.array(d_tr.labels.tolist())
            labels_va = np.array(d_va.labels.tolist())
            if labels_tr.shape[1] == 1:
                num_label = 2
                # Label 1 means "the code applies" and 0 means "it does not",
                # so the 'Other' name must sit at index 0.
                self.codebook = ['Other'] + self.codebook
                d_tr['labels'] = labels_tr[:, 0]
                d_va['labels'] = labels_va[:, 0]
            elif labels_tr.shape[1] > 1:
                label_none_tr = (labels_tr.sum(axis = 1) == 0).astype(int)
                label_none_va = (labels_va.sum(axis = 1) == 0).astype(int)
                if label_none_tr.any():  # some rows have none of the codes
                    labels_tr = np.hstack((labels_tr, label_none_tr.reshape(-1, 1)))
                    labels_va = np.hstack((labels_va, label_none_va.reshape(-1, 1)))
                    self.codebook = self.codebook + ['None of the above']
                # NOTE(review): if only the validation set has all-zero rows,
                # argmax below silently maps them to class 0.
                num_label = labels_tr.shape[1]
                d_tr['labels'] = np.argmax(labels_tr, axis = 1)
                d_va['labels'] = np.argmax(labels_va, axis = 1)
            else:
                raise ValueError("labels must contain at least one column")
        elif self.model_type == "Multi-label":
            num_label = len(d_tr.labels.values[0])
        else:
            raise ValueError(f"Unknown model_type: {self.model_type!r}")

        if self.weighted:
            if self.model_type == "Multi-class":
                # Per-class counts, including classes absent from training (count 0).
                freq = np.bincount(d_tr['labels'].to_numpy(dtype = int), minlength = num_label)
            else:
                freq = np.array(d_tr.labels.tolist()).sum(axis = 0)
            missing = np.flatnonzero(freq == 0).tolist()
            if missing:
                raise ValueError(
                    f"Cannot compute inverse-frequency weights: label index(es) {missing} "
                    "have no training examples")
            # Balanced weights: n_samples / (n_labels * frequency).
            self.weight = (num_obs / (num_label * freq)).tolist()
        else:
            self.weight = [1.] * num_label

        self.num_label = num_label
        self.d_tr = d_tr
        self.d_va = d_va

    # TODO: add a load_test_data(d_ts) step that applies the same label
    # conversion as data_setup() to the hold-out set and checks that it yields
    # the same number of labels as the training data.

    def data_resampling(self):
        """Re-sample the training data (not implemented yet)."""
        pass

    # ---- Evaluation metrics of the multi-class classifier -------------------
    # These take (y_true, y_pred) with 1-D integer class labels.

    def _f1_binary(self, y_true, y_pred):
        """Binary F1 of the positive class (label 1)."""
        return f1_score(y_true, y_pred, average = 'binary')

    def _f1_macro_mc(self, y_true, y_pred):
        """Unweighted mean of per-class F1."""
        return f1_score(y_true, y_pred, average = 'macro')

    def _f1_micro_mc(self, y_true, y_pred):
        """F1 computed from global TP/FP/FN counts."""
        return f1_score(y_true, y_pred, average = 'micro')

    def _f1_weighted_mc(self, y_true, y_pred):
        """Per-class F1 averaged with support weights."""
        return f1_score(y_true, y_pred, average = 'weighted')

    def _f1_max_mc(self, y_true, y_pred):
        """Best per-class F1."""
        return f1_score(y_true, y_pred, average = None).max()

    def _f1_min_mc(self, y_true, y_pred):
        """Worst per-class F1."""
        return f1_score(y_true, y_pred, average = None).min()

    def _prc_mc(self, y_true, y_pred):
        """One-vs-rest precision-recall curves for every class.

        Args:
            y_true: 1-D array of integer class labels.
            y_pred: ``(n_samples, num_label)`` array of class scores.

        Returns:
            ``(precision, recall, pr_auc)``: three lists with one entry per class.
        """
        y_true = np.asarray(y_true)
        precision, recall, pr_auc = [], [], []
        for i in range(self.num_label):
            is_class = (y_true == i).astype(int)
            class_precision, class_recall, _ = precision_recall_curve(is_class, y_pred[:, i])
            precision.append(class_precision)
            recall.append(class_recall)
            pr_auc.append(auc(class_recall, class_precision))
        return (precision, recall, pr_auc)

    # ---- ROC AUC -----------------------------------------------------------
    # https://scikit-learn.org/stable/modules/model_evaluation.html#roc-auc-multilabel

    def _auroc(self, y_true, y_pred, average):
        """ROC AUC with the given averaging.

        Multi-class models use sklearn's one-vs-one mode ('ovo'). Micro
        averaging is the exception: sklearn implements it only for one-vs-rest
        ('ovr'). Any other model type is scored as multi-label.
        """
        if self.model_type == "Multi-class":
            multi_class = "ovr" if average == "micro" else "ovo"
            return roc_auc_score(y_true, y_pred, average = average, multi_class = multi_class)
        return roc_auc_score(y_true, y_pred, average = average)

    def _auroc_micro(self, y_true, y_pred):
        """Micro-averaged ROC AUC."""
        return self._auroc(y_true, y_pred, "micro")

    def _auroc_macro(self, y_true, y_pred):
        """Macro-averaged ROC AUC."""
        return self._auroc(y_true, y_pred, "macro")

    def _auroc_weighted(self, y_true, y_pred):
        """Support-weighted ROC AUC."""
        return self._auroc(y_true, y_pred, "weighted")

    # ---- Evaluation metrics of the multi-label classifier -------------------
    # These take y_true as an (n_samples, n_labels) 0/1 matrix and scores of the
    # same shape. Scores above 0.5 count as positive predictions.

    def _per_label_f1(self, y_true, y_score):
        """Return the binary F1 of each label column, thresholding scores at 0.5."""
        y_pred = y_score > 0.5
        return [f1_score(y_true[:, i], y_pred[:, i]) for i in range(y_true.shape[1])]

    def _f1_macro_ml(self, y_true, y_score):
        """Mean per-label F1."""
        return np.mean(self._per_label_f1(y_true, y_score))

    def _f1_min_ml(self, y_true, y_score):
        """Worst per-label F1."""
        return np.min(self._per_label_f1(y_true, y_score))

    def _f1_max_ml(self, y_true, y_score):
        """Best per-label F1."""
        return np.max(self._per_label_f1(y_true, y_score))

    def _hamming_loss_ml(self, y_true, y_score):
        """Fraction of wrongly predicted label entries (lower is better)."""
        return hamming_loss(y_true, y_score > 0.5)

    def _lrap_ml(self, y_true, y_score):
        """Label ranking average precision. Also prints the value."""
        score = label_ranking_average_precision_score(y_true, y_score)
        print(f"LRAP = {score}")
        return score

    def _label_ranking_loss_ml(self, y_true, y_score):
        """Label ranking loss (lower is better)."""
        return label_ranking_loss(y_true, y_score)

    def _coverage_error_ml(self, y_true, y_score):
        """Coverage error (lower is better)."""
        return coverage_error(y_true, y_score)

    def _ap_micro_ml(self, y_true, y_pred):
        """Micro-averaged average precision."""
        return average_precision_score(y_true, y_pred, average = "micro")

    def _ap_macro_ml(self, y_true, y_pred):
        """Macro-averaged average precision."""
        return average_precision_score(y_true, y_pred, average = "macro")

    def _ap_weighted_ml(self, y_true, y_pred):
        """Support-weighted average precision."""
        return average_precision_score(y_true, y_pred, average = "weighted")

    def setup_model(self, save_model_every_epoch = True, save_no_model = False):
        """Create ``self.model`` with this project's training arguments.

        Call after data_setup(), which provides ``num_label`` and ``weight``.

        Args:
            save_model_every_epoch: Save a checkpoint at the end of each epoch.
            save_no_model: Passed through as ``no_save``.
        """
        multi_class = self.model_type == "Multi-class"
        if multi_class:
            print("Setting up multi-class classifier")
        else:
            print("Setting up multi-label classifier")

        # Arguments shared by both model types.
        shared_args = dict(
            manual_seed = self.random_seed,

            output_dir = os.path.join(self.model_directory, "output"),
            cache_dir = os.path.join(self.model_directory, "cache"),
            best_model_dir = os.path.join(self.model_directory, "best"),
            overwrite_output_dir = False,

            optimizer = "AdamW",
            adam_epsilon = 1e-8,

            num_train_epochs = self.n_epochs,

            evaluate_during_training = True,
            evaluate_during_training_verbose = False,

            # Early stopping is disabled. The best checkpoint is picked afterwards
            # from training_progress_scores.csv (see _load_model).
            use_early_stopping = False,

            save_eval_checkpoints = False,
            save_steps = -1,
            save_model_every_epoch = save_model_every_epoch,
            no_save = save_no_model,
            train_batch_size = 16,

            use_multiprocessing = False,
            use_multiprocessing_for_evaluation = False,
        )

        if multi_class:
            model_args = ClassificationArgs(
                **shared_args,
                learning_rate = 1e-6,
                polynomial_decay_schedule_lr_end = 1e-8,
                eval_batch_size = 16,
            )
        else:
            model_args = MultiLabelClassificationArgs(
                **shared_args,
                learning_rate = 5e-6,
                eval_batch_size = 8,
            )

        # See https://github.com/ThilinaRajapakse/simpletransformers/issues/638
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        if multi_class:
            self.model = ClassificationModel(
                self.transformer_spec[0], self.transformer_spec[1],
                num_labels = self.num_label,
                weight = self.weight,
                args = model_args,
            )
        else:
            self.model = MultiLabelClassificationModel(
                self.transformer_spec[0], self.transformer_spec[1],
                num_labels = self.num_label,
                args = model_args,
                pos_weight = self.weight,
            )

    def train_model(self):
        """Train ``self.model`` on ``d_tr`` and evaluate it on ``d_va`` during training.

        The extra metrics are passed as keyword arguments. Their names are the
        column names that _load_model/remove_bad_model later look up in
        training_progress_scores.csv. Unknown model types are ignored.
        """
        if self.model_type == "Multi-class":
            metrics = {}
            if self.num_label <= 2:
                metrics["f1_binary"] = self._f1_binary
            metrics.update(
                f1_macro = self._f1_macro_mc,
                f1_micro = self._f1_micro_mc,
                f1_weighted = self._f1_weighted_mc,
                f1_max = self._f1_max_mc,
                f1_min = self._f1_min_mc,
            )
            # NOTE: AP and AUROC metrics are currently disabled for multi-class models.
        elif self.model_type == "Multi-label":
            metrics = dict(
                f1_macro = self._f1_macro_ml,
                f1_max = self._f1_max_ml,
                f1_min = self._f1_min_ml,

                ap_micro = self._ap_micro_ml,
                ap_macro = self._ap_macro_ml,
                ap_weighted = self._ap_weighted_ml,

                auroc_micro = self._auroc_micro,
                auroc_macro = self._auroc_macro,
                auroc_weighted = self._auroc_weighted,

                hamming_loss = self._hamming_loss_ml,
                lrap = self._lrap_ml,
                label_ranking_loss = self._label_ranking_loss_ml,
                coverage_error = self._coverage_error_ml,
            )
        else:
            return

        self.model.train_model(train_df = self.d_tr, eval_df = self.d_va, **metrics)

    def _read_progress_scores(self):
        """Load the per-epoch rows of output/training_progress_scores.csv.

        Rows whose ``global_step`` is a multiple of 2000 are dropped. They are
        presumably the periodic mid-epoch evaluations; verify against the
        training args. The remaining rows are assumed to be one per epoch, in
        order, so row ``i`` maps to the checkpoint folder
        ``checkpoint-{global_step}-epoch-{i+1}`` (stored in column ``fname``).
        NOTE(review): an epoch ending exactly on a multiple of 2000 steps would
        be dropped and shift the epoch numbering.
        """
        path = os.path.join(self.model_directory, 'output', 'training_progress_scores.csv')
        scores = pd.read_csv(path)
        scores = scores.query('global_step % 2000 != 0').reset_index(drop = True)
        scores['fname'] = [f'checkpoint-{step}-epoch-{i + 1}' for i, step in enumerate(scores.global_step)]
        return scores

    def remove_bad_model(self, eval_metrics):
        """Delete every checkpoint directory except the best one per metric.

        For losses in ``_LOWER_IS_BETTER`` the minimum counts as best; for any
        other metric, the maximum. Ties go to the earliest epoch.

        Args:
            eval_metrics: Column names of training_progress_scores.csv.
        """
        output_directory = os.path.join(self.model_directory, 'output')
        scores = self._read_progress_scores()

        fname_keep = set()
        for metric in eval_metrics:
            column = scores[metric]
            best = column.idxmin() if metric in self._LOWER_IS_BETTER else column.idxmax()
            fname_keep.add(scores.fname[best])

        for name in os.listdir(output_directory):
            path = os.path.join(output_directory, name)
            if os.path.isdir(path) and name not in fname_keep:
                shutil.rmtree(path)
                print(f"Removed {name}")

    def _load_model(self, load_model = True, eval_metric = 'f1_macro'):
        """Select the best checkpoint by ``eval_metric`` and optionally load it.

        Sets ``table_training_progress`` (the progress scores, best first) and
        ``model_name``. When ``load_model`` is true, also sets ``self.model``
        using the architecture in ``transformer_spec[0]``.
        """
        scores = self._read_progress_scores()
        scores = scores.sort_values(
            by = [eval_metric],
            ascending = eval_metric in self._LOWER_IS_BETTER,
            kind = 'mergesort',  # stable: ties keep epoch order
        )

        best_model = scores.iloc[0, :]
        path_best_model = os.path.join(self.model_directory, "output", best_model.fname)

        self.table_training_progress = scores
        self.model_name = f"{best_model.fname} (best {eval_metric})"
        print(best_model)

        if load_model:
            model_class = ClassificationModel if self.model_type == "Multi-class" else MultiLabelClassificationModel
            self.model = model_class(self.transformer_spec[0], path_best_model)
            print(f"Model '{best_model.fname}' with best {eval_metric} is loaded.")

    def evaluate_best_model(self, d_ts, evaluate_model = True, best_model_metric = 'f1_macro'):
        """Predict on the hold-out set with the best checkpoint.

        With ``evaluate_model=True`` the best checkpoint is loaded and run on
        ``d_ts``, and the predictions are written to
        ``eval/<model_name>/y_pred_prob.csv`` and ``y_actual.csv``. Otherwise
        those files, saved by an earlier run, are read back.

        Sets ``y_pred_prob`` (n_samples, num_label), ``y_pred`` and ``y_actual``,
        and ``num_label``. For multi-class, ``y_pred`` and ``y_actual`` are 1-D
        class indices. For multi-label, they are 0/1 matrices with predictions
        thresholded at 0.5.
        """
        self._load_model(load_model = evaluate_model, eval_metric = best_model_metric)

        eval_dir = os.path.join(self.model_directory, 'eval', self.model_name)
        os.makedirs(eval_dir, exist_ok = True)

        multi_class = self.model_type == "Multi-class"
        if evaluate_model:
            _, model_outputs, _ = self.model.eval_model(d_ts)
            if multi_class:
                y_pred_prob = softmax(model_outputs, axis = 1)
                y_actual = d_ts.labels.values
            else:
                y_pred_prob = model_outputs
                y_actual = np.array(d_ts.labels.tolist())

            pd.DataFrame(y_pred_prob).to_csv(os.path.join(eval_dir, "y_pred_prob.csv"), header = False, index = False)
            pd.DataFrame(y_actual).to_csv(os.path.join(eval_dir, "y_actual.csv"), header = False, index = False)
        else:
            y_pred_prob = pd.read_csv(os.path.join(eval_dir, "y_pred_prob.csv"), header = None).values
            y_actual = pd.read_csv(os.path.join(eval_dir, "y_actual.csv"), header = None).values

        y_pred_prob = np.asarray(y_pred_prob)
        if multi_class:
            # The CSV round-trip turns the 1-D labels into an (n, 1) column.
            y_actual = np.asarray(y_actual).ravel().astype(int)
            y_pred = np.argmax(y_pred_prob, axis = 1)
        else:
            y_actual = np.asarray(y_actual)
            y_pred = (y_pred_prob >= 0.5) + 0

        self.y_pred_prob = y_pred_prob
        self.y_pred = y_pred
        self.y_actual = y_actual
        self.num_label = y_pred_prob.shape[1]

    def visualize_train(self):
        """Plot how the evaluation metrics change during training (not implemented yet)."""
        pass

    def _save_figure(self, eval_dir, fname, print_fig):
        """Save the current figure to ``eval_dir/fname``, optionally show it, then close it."""
        plt.savefig(os.path.join(eval_dir, fname))
        if print_fig:
            plt.show()
        # Close the figure so open figures do not pile up when there are many labels.
        plt.close()

    def visualize_evaluation(self, figure_title, print_fig):
        """Save confusion matrices, PR/ROC curves and a classification report.

        Requires evaluate_best_model() to have been run. Each label (multi-label)
        or class (multi-class, one-vs-rest) gets its own confusion matrices and
        curve. Files go to ``<model_directory>/eval/<model_name>/``.

        Args:
            figure_title: Title prefix for the PR and ROC figures.
            print_fig: If truthy, also display each figure with ``plt.show()``.
        """
        eval_dir = os.path.join(self.model_directory, 'eval', self.model_name)
        labels = self.codebook
        cm_label = ['No', 'Yes']

        # 0/1 indicator matrices, one column per label. Multi-class labels are
        # one-hot encoded so that each class is evaluated one-vs-rest.
        if self.y_actual.ndim == 1:
            one_hot = np.eye(self.num_label, dtype = int)
            y_true = one_hot[self.y_actual]
            y_hat = one_hot[self.y_pred]
        else:
            y_true, y_hat = self.y_actual, self.y_pred

        ## Confusion matrices normalized by actual class.
        # labels=[0, 1] keeps every matrix 2x2 even if one value never occurs.
        for i in range(self.num_label):
            cm = confusion_matrix(y_true[:, i], y_hat[:, i], labels = [0, 1], normalize = 'true')
            plt.figure(figsize = (5, 4))
            sns.heatmap(cm, annot = True, vmin = 0, vmax = 1, fmt = ".2f", cmap = "Blues", xticklabels = cm_label, yticklabels = cm_label)
            plt.xlabel('Predicted (Machine-labeled)')
            plt.ylabel('Actual (Human-labeled)')
            plt.title(f'Confusion Matrix (norm)\n{labels[i]}', size = 'medium')
            self._save_figure(eval_dir, f"confusion_matrix_norm_{i}.pdf", print_fig)

        ## Confusion matrices with raw counts
        for i in range(self.num_label):
            cm = confusion_matrix(y_true[:, i], y_hat[:, i], labels = [0, 1])
            plt.figure(figsize = (5, 4))
            sns.heatmap(cm, annot = True, fmt = "d", cmap = "Blues", xticklabels = cm_label, yticklabels = cm_label)
            plt.xlabel('Predicted (Machine-labeled)')
            plt.ylabel('Actual (Human-labeled)')
            plt.title(f'Confusion Matrix\n{labels[i]}', size = 'medium')
            self._save_figure(eval_dir, f"confusion_matrix_{i}.pdf", print_fig)

        ## Precision-recall curves
        plt.figure(figsize = (10, 8))
        for i in range(self.num_label):
            precision, recall, _ = precision_recall_curve(y_true[:, i], self.y_pred_prob[:, i])
            ap = average_precision_score(y_true[:, i], self.y_pred_prob[:, i], pos_label = 1)
            plt.plot(recall, precision, lw = 2,
                     label = f'{labels[i]} (AUC={auc(recall, precision):0.2f}, AP={ap:0.2f})')
        plt.xlim([-0.02, 1.0])
        plt.ylim([0.0, 1.02])
        plt.xlabel('Recall', size = 'x-large')
        plt.ylabel('Precision', size = 'x-large')
        plt.title(f'{figure_title}\nPrecision-Recall Curve', size = 'xx-large')
        plt.legend(loc = 'lower left')
        # Iso-F1 contours: precision = f * recall / (2 * recall - f).
        x = np.linspace(0.005, 1)
        for f_score in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]:
            y = f_score * x / (2 * x - f_score)
            plt.plot(x[y >= 0], y[y >= 0], color = 'gray', alpha = 0.2)
            plt.annotate('F1={0:0.1f}'.format(f_score), xy = (0.9, y[45] + 0.02))
        self._save_figure(eval_dir, "pr_curve.pdf", print_fig)

        ## ROC curves
        plt.figure(figsize = (10, 8))
        plt.plot([0, 1], [0, 1], color = 'gray', lw = 2, linestyle = '--')
        for i in range(self.num_label):
            fpr, tpr, _ = roc_curve(y_true[:, i], self.y_pred_prob[:, i])
            plt.plot(fpr, tpr, lw = 2, label = f'{labels[i]} (AUC = {auc(fpr, tpr):0.2f})')
        plt.xlim([-0.02, 1.0])
        plt.ylim([0.0, 1.02])
        plt.xlabel('False Positive Rate', size = 'x-large')
        plt.ylabel('True Positive Rate', size = 'x-large')
        plt.title(f'{figure_title}\nReceiver Operating Characteristic', size = 'xx-large')
        plt.legend(loc = 'lower right')
        self._save_figure(eval_dir, "roc_curve.pdf", print_fig)

        ## Classification report (CSV and LaTeX), with rows named after the codebook
        colnames = {str(i): name for i, name in enumerate(self.codebook)}
        report = classification_report(self.y_actual, self.y_pred, output_dict = True)
        report = pd.DataFrame(report).rename(columns = colnames).transpose().round(2)
        report = report.astype({'support': "int32"})

        report.to_csv(os.path.join(eval_dir, "classification_report.csv"))
        with open(os.path.join(eval_dir, "classification_report.tex"), 'w') as f:
            report.to_latex(buf = f, index = True)

    def apply_model(self):
        """Apply the trained model to new data (not implemented yet)."""
        pass